{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "pip install lightly\n",
    "pip install \"scikit-learn>=1.7.1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install lightly"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the paper. The settings are chosen such that the example can easily be\n",
    "run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "from sklearn.cluster import KMeans\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.models import utils\n",
    "from lightly.models.modules.heads import (\n",
    "    SMoGPredictionHead,\n",
    "    SMoGProjectionHead,\n",
    "    SMoGPrototypes,\n",
    ")\n",
    "from lightly.models.modules.memory_bank import MemoryBankModule\n",
    "from lightly.transforms.smog_transform import SMoGTransform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SMoGModel(nn.Module):\n",
    "    def __init__(self, backbone):\n",
    "        super().__init__()\n",
    "\n",
    "        self.backbone = backbone\n",
    "        self.projection_head = SMoGProjectionHead(512, 2048, 128)\n",
    "        self.prediction_head = SMoGPredictionHead(128, 2048, 128)\n",
    "\n",
    "        self.backbone_momentum = copy.deepcopy(self.backbone)\n",
    "        self.projection_head_momentum = copy.deepcopy(self.projection_head)\n",
    "\n",
    "        utils.deactivate_requires_grad(self.backbone_momentum)\n",
    "        utils.deactivate_requires_grad(self.projection_head_momentum)\n",
    "\n",
    "        self.n_groups = 300\n",
    "        self.smog = SMoGPrototypes(\n",
    "            group_features=torch.rand(self.n_groups, 128), beta=0.99\n",
    "        )\n",
    "\n",
    "    def _cluster_features(self, features: torch.Tensor) -> torch.Tensor:\n",
    "        # clusters the features using sklearn\n",
    "        # (note: faiss is probably more efficient)\n",
    "        features = features.cpu().numpy()\n",
    "        kmeans = KMeans(self.n_groups).fit(features)\n",
    "        clustered = torch.from_numpy(kmeans.cluster_centers_).float()\n",
    "        clustered = torch.nn.functional.normalize(clustered, dim=1)\n",
    "        return clustered\n",
    "\n",
    "    def reset_group_features(self, memory_bank):\n",
    "        # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n",
    "        features = memory_bank.bank\n",
    "        group_features = self._cluster_features(features)\n",
    "        self.smog.set_group_features(group_features)\n",
    "\n",
    "    def reset_momentum_weights(self):\n",
    "        # see https://arxiv.org/pdf/2207.06167.pdf Table 7b)\n",
    "        self.backbone_momentum = copy.deepcopy(self.backbone)\n",
    "        self.projection_head_momentum = copy.deepcopy(self.projection_head)\n",
    "        utils.deactivate_requires_grad(self.backbone_momentum)\n",
    "        utils.deactivate_requires_grad(self.projection_head_momentum)\n",
    "\n",
    "    def forward(self, x):\n",
    "        features = self.backbone(x).flatten(start_dim=1)\n",
    "        encoded = self.projection_head(features)\n",
    "        predicted = self.prediction_head(encoded)\n",
    "        return encoded, predicted\n",
    "\n",
    "    def forward_momentum(self, x):\n",
    "        features = self.backbone_momentum(x).flatten(start_dim=1)\n",
    "        encoded = self.projection_head_momentum(features)\n",
    "        return encoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 256"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "resnet = torchvision.models.resnet18()\n",
    "backbone = nn.Sequential(*list(resnet.children())[:-1])\n",
    "model = SMoGModel(backbone)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# memory bank because we reset the group features every 300 iterations\n",
    "memory_bank_size = 300 * batch_size\n",
    "memory_bank = MemoryBankModule(size=(memory_bank_size, 128))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = SMoGTransform(\n",
    "    crop_sizes=(32, 32),\n",
    "    crop_counts=(1, 1),\n",
    "    gaussian_blur_probs=(0.0, 0.0),\n",
    "    crop_min_scales=(0.2, 0.2),\n",
    "    crop_max_scales=(1.0, 1.0),\n",
    ")\n",
    "dataset = torchvision.datasets.CIFAR10(\n",
    "    \"datasets/cifar10\", download=True, transform=transform\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos:\n",
    "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13",
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(\n",
    "    model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-6\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14",
   "metadata": {},
   "outputs": [],
   "source": [
    "global_step = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Starting Training\")\n",
    "for epoch in range(10):\n",
    "    total_loss = 0\n",
    "    for batch_idx, batch in enumerate(dataloader):\n",
    "        (x0, x1) = batch[0]\n",
    "\n",
    "        if batch_idx % 2:\n",
    "            # swap batches every second iteration\n",
    "            x1, x0 = x0, x1\n",
    "\n",
    "        x0 = x0.to(device)\n",
    "        x1 = x1.to(device)\n",
    "\n",
    "        if global_step > 0 and global_step % 300 == 0:\n",
    "            # reset group features and weights every 300 iterations\n",
    "            model.reset_group_features(memory_bank=memory_bank)\n",
    "            model.reset_momentum_weights()\n",
    "        else:\n",
    "            # update momentum\n",
    "            utils.update_momentum(model.backbone, model.backbone_momentum, 0.99)\n",
    "            utils.update_momentum(\n",
    "                model.projection_head, model.projection_head_momentum, 0.99\n",
    "            )\n",
    "\n",
    "        x0_encoded, x0_predicted = model(x0)\n",
    "        x1_encoded = model.forward_momentum(x1)\n",
    "\n",
    "        # update group features and get group assignments\n",
    "        assignments = model.smog.assign_groups(x1_encoded)\n",
    "        group_features = model.smog.get_updated_group_features(x0_encoded)\n",
    "        logits = model.smog(x0_predicted, group_features, temperature=0.1)\n",
    "        model.smog.set_group_features(group_features)\n",
    "\n",
    "        loss = criterion(logits, assignments)\n",
    "\n",
    "        # use memory bank to periodically reset the group features with k-means\n",
    "        memory_bank(x0_encoded, update=True)\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        global_step += 1\n",
    "        total_loss += loss.detach()\n",
    "\n",
    "    avg_loss = total_loss / len(dataloader)\n",
    "    print(f\"epoch: {epoch:>02}, loss: {avg_loss:.5f}\")"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
